{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os \n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from typing import Optional\n",
    "import random\n",
    "from diffusers import DDIMScheduler\n",
    "import numpy as np\n",
    "from Freeprompt.diffuser_utils import FreePromptPipeline\n",
    "from Freeprompt.freeprompt_utils import register_attention_control_new\n",
    "from torchvision.utils import save_image\n",
    "from torchvision.io import read_image\n",
    "from Freeprompt.freeprompt import SelfAttentionControlEdit,AttentionStore\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Model Construction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"id2label\"]` will be overriden.\n",
      "/apsarapangu/disk1/wushi.lby/anaconda3/envs/ldm/lib/python3.8/site-packages/transformers/models/clip/feature_extraction_clip.py:28: FutureWarning: The class CLIPFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please use CLIPImageProcessor instead.\n",
      "  warnings.warn(\n",
      "/apsarapangu/disk1/wushi.lby/anaconda3/envs/ldm/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py:115: FutureWarning: The configuration file of this scheduler: DDIMScheduler {\n",
      "  \"_class_name\": \"DDIMScheduler\",\n",
      "  \"_diffusers_version\": \"0.16.1\",\n",
      "  \"beta_end\": 0.012,\n",
      "  \"beta_schedule\": \"scaled_linear\",\n",
      "  \"beta_start\": 0.00085,\n",
      "  \"clip_sample\": false,\n",
      "  \"clip_sample_range\": 1.0,\n",
      "  \"dynamic_thresholding_ratio\": 0.995,\n",
      "  \"num_train_timesteps\": 1000,\n",
      "  \"prediction_type\": \"epsilon\",\n",
      "  \"sample_max_value\": 1.0,\n",
      "  \"set_alpha_to_one\": false,\n",
      "  \"steps_offset\": 0,\n",
      "  \"thresholding\": false,\n",
      "  \"trained_betas\": null\n",
      "}\n",
      " is outdated. `steps_offset` should be set to 1 instead of 0. Please make sure to update the config accordingly as leaving `steps_offset` might led to incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json` file\n",
      "  deprecate(\"steps_offset!=1\", \"1.0.0\", deprecation_message, standard_warn=False)\n"
     ]
    }
   ],
   "source": [
    "# Note that you may add your Hugging Face token to get access to the models\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "model_path = \"runwayml/stable-diffusion-v1-5\"\n",
    "\n",
    "scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", clip_sample=False, set_alpha_to_one=False)\n",
    "pipe = FreePromptPipeline.from_pretrained(model_path, scheduler=scheduler).to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Fake image editing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/apsarapangu/disk1/wushi.lby/image_edit/sup/code/Free_prompt_editing/Freeprompt/diffuser_utils.py:134: FutureWarning: Accessing config attribute `in_channels` directly via 'UNet2DConditionModel' object attribute is deprecated. Please access 'in_channels' over 'UNet2DConditionModel's config object instead, e.g. 'unet.config.in_channels'.\n",
      "  latents_shape = (batch_size, self.unet.in_channels, height//8, width//8)\n",
      "DDIM Sampler: 100%|██████████| 50/50 [00:10<00:00,  4.68it/s]\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Note that you may add your Hugging Face token to get access to the models\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "\n",
    "def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:\n",
    "\n",
    "    if seed is None:\n",
    "        seed = os.environ.get(\"PL_GLOBAL_SEED\")\n",
    "    seed = int(seed)\n",
    "\n",
    "    os.environ[\"PL_GLOBAL_SEED\"] = str(seed)\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "\n",
    "    os.environ[\"PL_SEED_WORKERS\"] = f\"{int(workers)}\"\n",
    "\n",
    "    return seed\n",
    "\n",
    "seed = 42\n",
    "seed_everything(seed)\n",
    "\n",
    "self_replace_steps = .8\n",
    "NUM_DIFFUSION_STEPS = 50\n",
    "\n",
    "out_dir = \"examples/outputs\"\n",
    "\n",
    "\n",
    "source_prompt = 'a photo of green ducks walking on street'\n",
    "target_prompt = 'a photo of rubber ducks walking on street'\n",
    "\n",
    "start_code = torch.randn([1, 4, 64, 64], device=device)\n",
    "start_code = start_code.expand(2, -1, -1, -1)\n",
    "\n",
    "latents = torch.randn(start_code.shape, device=device)\n",
    "prompts = [source_prompt, target_prompt]\n",
    "\n",
    "start_code = start_code.expand(len(prompts), -1, -1, -1)\n",
    "controller = SelfAttentionControlEdit(prompts, NUM_DIFFUSION_STEPS, self_replace_steps=self_replace_steps)\n",
    "\n",
    "register_attention_control_new(pipe, controller)\n",
    "\n",
    "results = pipe(prompts,\n",
    "                    latents=start_code,\n",
    "                    guidance_scale=7.5,\n",
    "                    )\n",
    "\n",
    "save_image(results[0], os.path.join(out_dir, str(source_prompt)+'.jpg'))\n",
    "save_image(results[1], os.path.join(out_dir, str(target_prompt)+'.jpg'))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.5 ('ldm')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "587aa04bacead72c1ffd459abbe4c8140b72ba2b534b24165b36a2ede3d95042"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
